/*
    Copyright (C) 2022 Tenable, Inc.

	Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at

		http://www.apache.org/licenses/LICENSE-2.0

	Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS,
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    See the License for the specific language governing permissions and
    limitations under the License.
*/

package vulnerability

import (
	"context"
	"fmt"
	"strings"
	"time"

	"github.com/aws/aws-sdk-go/aws/request"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/ecr"
	"github.com/tenable/terrascan/pkg/iac-providers/output"
	"go.uber.org/zap"
)

// ecrURL is the domain suffix checkRegistry looks for to recognise ECR images.
const ecrURL = "amazonaws.com"

// Format strings shared by the ECR scanning error paths.
var (
	errorScanningMsg         = "error scanning image %s : %v"
	invalidImageReferenceMsg = "invalid image reference %s "
)

// ECR scans images hosted in an Amazon Elastic Container Registry. An
// instance is registered under the name "ecr" by this file's init function.
type ECR struct {
	// scanner performs the underlying ECR API calls; keeping it behind the
	// ecrScanner interface lets an alternative implementation be supplied
	// (presumably for tests — confirm).
	scanner ecrScanner
}

// escanner is the default, stateless ecrScanner implementation; each of its
// methods delegates straight to the AWS SDK ECR client.
type escanner struct{}

func init() {
	RegisterContainerRegistry("ecr", &ECR{
		scanner: escanner{},
	})
}

// ecrScanner is the set of AWS ECR operations ECR depends on. Routing them
// through an interface keeps the concrete SDK calls swappable.
type ecrScanner interface {
	// newClient builds an ECR API client.
	newClient() *ecr.ECR
	// startImageScanWithContext requests a new vulnerability scan of an image.
	startImageScanWithContext(ctx context.Context, client *ecr.ECR, input *ecr.StartImageScanInput, opts ...request.Option) (*ecr.StartImageScanOutput, error)
	// waitUntilImageScanComplete waits for the scan identified by input to finish.
	waitUntilImageScanComplete(client *ecr.ECR, input *ecr.DescribeImageScanFindingsInput) error
	// describeImageScanFindingsWithContext fetches an image's scan status and findings.
	describeImageScanFindingsWithContext(ctx context.Context, client *ecr.ECR, input *ecr.DescribeImageScanFindingsInput, opts ...request.Option) (*ecr.DescribeImageScanFindingsOutput, error)
}

// newClient returns a new ECR client built from a default AWS session.
//
// session.Must panics if the session cannot be created (for example because
// of invalid shared configuration), so callers should only reach this once the
// inputs they intend to scan have been validated.
func (escanner) newClient() *ecr.ECR {
	// Named awsSession so the local does not shadow the imported session package.
	awsSession := session.Must(session.NewSession())
	return ecr.New(awsSession)
}

// describeImageScanFindingsWithContext returns the scan status and findings for
// the image identified by input, delegating to the SDK's
// DescribeImageScanFindingsWithContext. Any request options are forwarded to
// the SDK call; previously they were accepted but silently dropped.
func (escanner) describeImageScanFindingsWithContext(ctx context.Context, client *ecr.ECR, input *ecr.DescribeImageScanFindingsInput, opts ...request.Option) (*ecr.DescribeImageScanFindingsOutput, error) {
	return client.DescribeImageScanFindingsWithContext(ctx, input, opts...)
}

// waitUntilImageScanComplete blocks on the SDK's ImageScanComplete waiter for
// the image identified by input and returns whatever error the waiter reports.
// No context is handed to the waiter.
func (escanner) waitUntilImageScanComplete(client *ecr.ECR, input *ecr.DescribeImageScanFindingsInput) error {
	waitErr := client.WaitUntilImageScanComplete(input)
	return waitErr
}

// startImageScanWithContext asks ECR to start a vulnerability scan of the image
// identified by input, delegating to the SDK's StartImageScanWithContext. Any
// request options are forwarded to the SDK call; previously they were accepted
// but silently dropped.
func (escanner) startImageScanWithContext(ctx context.Context, client *ecr.ECR, input *ecr.StartImageScanInput, opts ...request.Option) (*ecr.StartImageScanOutput, error) {
	return client.StartImageScanWithContext(ctx, input, opts...)
}

// checkRegistry reports whether image belongs to an ECR registry, i.e. whether
// the domain extracted by GetDomain ends with ecrURL.
func (e *ECR) checkRegistry(image string) bool {
	return strings.HasSuffix(GetDomain(image), ecrURL)
}

// GetVulnerabilities - get vulnerabilities from ecr registry
func (e *ECR) getVulnerabilities(container output.ContainerDetails, options map[string]interface{}) (vulnerabilities []output.Vulnerability) {
	results, err := e.ScanImage(context.Background(), container.Image)
	if err != nil {
		zap.S().Errorf("error finding vulnerabilities for image %s : %v", container.Image, err)
		return
	}
	hasFindings := results != nil &&
		results.ImageScanFindings != nil &&
		results.ImageScanFindings.Findings != nil
	if hasFindings {
		for _, result := range results.ImageScanFindings.Findings {
			vulnerability := output.Vulnerability{}
			vulnerability.PrepareFromECRImageScan(result)
			vulnerabilities = append(vulnerabilities, vulnerability)
		}
	}
	return vulnerabilities
}

// ScanImage calls the AWS ECR API to get the scan findings for image.
//
// The reference is parsed with GetImageDetails. When neither a tag nor a
// repository can be extracted, an "invalid image reference" error is returned.
// A reference with neither a tag nor a digest falls back to defaultTagValue.
//
// The ECR client is created only after the reference has been validated. The
// default escanner.newClient uses session.Must, which panics when no AWS
// session can be built, so a malformed image string must not reach it.
func (e *ECR) ScanImage(ctx context.Context, image string) (*ecr.DescribeImageScanFindingsOutput, error) {
	imageDetails := GetImageDetails(image, ImageDetails{})

	if imageDetails.Tag == "" && imageDetails.Repository == "" {
		zap.S().Errorf(invalidImageReferenceMsg, image)
		return nil, fmt.Errorf(invalidImageReferenceMsg, image)
	}

	if imageDetails.Tag == "" && imageDetails.Digest == "" {
		imageDetails.Tag = defaultTagValue
	}

	client := e.scanner.newClient()
	return e.GetImageScanResult(ctx, client, image, imageDetails)
}

// StartImageScan starts the scan of provided image
func (e *ECR) StartImageScan(ctx context.Context, client *ecr.ECR, image string, imageDetails ImageDetails) error {
	imageIdentifier := ecr.ImageIdentifier{}
	if imageDetails.Digest != "" {
		imageIdentifier.ImageDigest = &imageDetails.Digest
	} else {
		imageIdentifier.ImageTag = &imageDetails.Tag
	}
	input := ecr.StartImageScanInput{
		ImageId:        &imageIdentifier,
		RepositoryName: &imageDetails.Repository,
	}

	if _, err := e.scanner.startImageScanWithContext(ctx, client, &input); err != nil {
		zap.S().Errorf(errorScanningMsg, image, err)
		return err
	}

	describeImageScanFindingsInput := ecr.DescribeImageScanFindingsInput{
		ImageId:        &imageIdentifier,
		RepositoryName: &imageDetails.Repository,
	}

	// wait until scan of image is complete
	if err := e.scanner.waitUntilImageScanComplete(client, &describeImageScanFindingsInput); err != nil {
		zap.S().Errorf(errorScanningMsg, image, err)
		return err
	}
	return nil
}

// GetImageScanResult get the scan result from ECR
func (e *ECR) GetImageScanResult(ctx context.Context, client *ecr.ECR, image string, imageDetails ImageDetails) (*ecr.DescribeImageScanFindingsOutput, error) {
	imageIdentifier := ecr.ImageIdentifier{}
	if imageDetails.Digest != "" {
		imageIdentifier.ImageDigest = &imageDetails.Digest
	} else {
		imageIdentifier.ImageTag = &imageDetails.Tag
	}
	describeImageScanFindingsInput := ecr.DescribeImageScanFindingsInput{
		ImageId:        &imageIdentifier,
		RepositoryName: &imageDetails.Repository,
	}
	results, err := e.scanner.describeImageScanFindingsWithContext(ctx, client, &describeImageScanFindingsInput)
	if err != nil {
		if _, ok := err.(*ecr.ScanNotFoundException); ok {
			err := e.StartImageScan(ctx, client, image, imageDetails)
			if err != nil {
				return results, err
			}
			return e.GetImageScanResult(ctx, client, image, imageDetails)
		}
		zap.S().Errorf(errorScanningMsg, image, err)
		return results, err
	}
	hasFindings := results != nil &&
		results.ImageScanStatus != nil &&
		results.ImageScanStatus.Status != nil &&
		results.ImageScanFindings != nil &&
		results.ImageScanFindings.ImageScanCompletedAt != nil

	if hasFindings {
		if strings.EqualFold(*results.ImageScanStatus.Status, ecr.ScanStatusComplete) &&
			(*results.ImageScanFindings.ImageScanCompletedAt).Before(time.Now().Add(24*time.Hour)) {
			return results, nil
		} else if strings.EqualFold(*results.ImageScanStatus.Status, ecr.ScanStatusFailed) {
			return results, nil
		} else if strings.EqualFold(*results.ImageScanStatus.Status, ecr.ScanStatusInProgress) {
			return e.GetImageScanResult(ctx, client, image, imageDetails)
		} else if strings.EqualFold(*results.ImageScanStatus.Status, ecr.ScanStatusComplete) &&
			!((*results.ImageScanFindings.ImageScanCompletedAt).Before(time.Now().Add(24 * time.Hour))) {
			err := e.StartImageScan(ctx, client, image, imageDetails)
			if err != nil {
				return results, err
			}
			return e.GetImageScanResult(ctx, client, image, imageDetails)
		}

	}
	return results, nil
}
